[Pytorch][Bug] DCP Checkpoint Loading Fixes for FSDP2 with QuantizedModelInit #2974
vthumbe1503 wants to merge 4 commits into
Conversation
Greptile Summary
This PR fixes DCP sync and async checkpoint loading for MXFP8/NVFP4 and other quantization recipes under FSDP2 with quantized model init.
Confidence Score: 5/5
Safe to merge; the DCP checkpointing paths are well-scoped and the two edge cases noted are unlikely to be exercised in the FSDP2 checkpoint flow. The changes are tightly targeted at the DCP checkpointing flow. Important files changed include transformer_engine/pytorch/quantized_tensor.py.
Sequence Diagram
```mermaid
sequenceDiagram
participant FSDP2
participant QT as QuantizedTensor
participant Dispatch as __torch_dispatch__
participant Storage as Internal Buffers
Note over FSDP2,Storage: DCP Async Save
FSDP2->>QT: "tensor.to(device=cpu)"
QT->>Dispatch: aten._to_copy.default
Dispatch->>Storage: move all internal buffers to CPU
Dispatch-->>FSDP2: CPU QuantizedTensor (preserved type)
FSDP2->>QT: torch.save(cpu_tensor, ...)
QT->>QT: "__reduce_ex__ - _make_*_in_reduce_ex (fp8_dtype as int)"
FSDP2->>QT: "torch.load(..., weights_only=True)"
Note right of QT: add_safe_globals enables reconstruction
QT-->>FSDP2: QuantizedTensor restored
Note over FSDP2,Storage: DCP Sync Load - same-tensor check
FSDP2->>QT: param.untyped_storage()
QT-->>FSDP2: UntypedStorage(0 bytes)
FSDP2->>FSDP2: "storage(bf16_ckpt) == storage(fp8_param)?"
Note right of FSDP2: Always False - no false same-tensor match
```
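The torch.save / torch.load(weights_only=True) leg of the diagram relies on PyTorch's safe-globals mechanism. A minimal, self-contained illustration of that mechanism (MyWrapper is a generic stand-in class, not a TE type):

```python
import io
import torch

class MyWrapper:
    """Toy payload class standing in for a tensor subclass."""
    def __init__(self, payload):
        self.payload = payload

# Without this registration, torch.load(..., weights_only=True) refuses
# to reconstruct MyWrapper from the pickle stream.
torch.serialization.add_safe_globals([MyWrapper])

buf = io.BytesIO()
torch.save(MyWrapper(torch.ones(2)), buf)
buf.seek(0)
restored = torch.load(buf, weights_only=True)
print(type(restored).__name__, restored.payload)
```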
Reviews (4): Last reviewed commit: "address review comment"
```python
def untyped_storage(self) -> torch.UntypedStorage:
    """Return an empty UntypedStorage on the tensor's device.

    ``QuantizedTensor`` is a ``_make_wrapper_subclass`` and has no real
    backing storage of its own; the actual bytes live in the inner
    buffers (e.g. ``_rowwise_data`` / ``_columnwise_data``) which are
    an implementation detail of the quantization scheme. Need to define
    this method to avoid DCP staging errors with FSDP2.
    """
    return torch.UntypedStorage(0, device=self.device)
```
Empty storage breaks shared-storage detection in existing callers
QuantizedTensor.untyped_storage() now returns a freshly allocated zero-byte storage every call. Code in module/_common.py:128 compares tensors[0].untyped_storage().nbytes() against expected size to decide between a no-op view and an out-of-place torch.cat. With 0 bytes returned, that condition is always true, silently disabling the in-place fast path for any QuantizedTensor through ConcatMerge.forward. More critically, utils.py:403-412 in SplitAlongDim.backward uses data_ptr() for noop detection — if all zero-size CUDA allocations return data_ptr() == 0, every QuantizedTensor pair incorrectly appears co-located, setting noop_ok = True and crashing on ret.set_() against a 0-byte storage.
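A minimal sketch of the failure mode described above, using bare storages rather than TE code (whether a zero-size allocation reports data_ptr() == 0 is an assumption that can vary by device and allocator):

```python
import torch

# Every call to the overridden untyped_storage() returns a fresh 0-byte
# storage; model that with two unrelated empty storages.
a = torch.UntypedStorage(0)
b = torch.UntypedStorage(0)

# Size-based fast-path check (ConcatMerge-style): a 0-byte storage never
# matches the expected byte count, so the in-place view path is skipped.
expected_nbytes = 4096
print(a.nbytes() == expected_nbytes)  # always False

# Pointer-based co-location check (SplitAlongDim-style): if zero-size
# allocations report the same data_ptr(), two unrelated tensors appear
# to share memory and the no-op path is taken incorrectly.
print(a.data_ptr() == b.data_ptr())
```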
The correct behavior for these functions is to fall back to the slow path for QuantizedTensors, unless there is a dedicated implementation to handle quantized data.
Yeah, while I don't think we ever use QuantizedTensors in SplitAlongDim, the concat path sounds plausible to hit.
```python
elif recipe_name == "NVFP4BlockScaling":
    # NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances
    torch.testing.assert_close(
        loaded_output,
        ref_output,
        rtol=0.125,
        atol=0.25,
        msg=lambda x: f"NVFP4BlockScaling: Fresh model loaded from DCP checkpoint produces different output: {x}",
    )
```
Typo: "neec" should be "need" — appears in both NVFP4 tolerance blocks.
Suggested change:
```diff
 elif recipe_name == "NVFP4BlockScaling":
-    # NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances
+    # NVFP4 DCP load goes through a dequant + quant, so need to relax tolerances
     torch.testing.assert_close(
         loaded_output,
         ref_output,
         rtol=0.125,
         atol=0.25,
         msg=lambda x: f"NVFP4BlockScaling: Fresh model loaded from DCP checkpoint produces different output: {x}",
     )
```
```python
elif recipe_name == "NVFP4BlockScaling":
    # NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances
    torch.testing.assert_close(
        out2,
        out1,
        rtol=0.125,
        atol=0.25,
        msg=lambda x: f"NVFP4BlockScaling: Training step after DCP load produces different output: {x}",
    )
```
Same typo ("neec") in the second NVFP4 tolerance block for the post-training-step check.
Suggested change:
```diff
 elif recipe_name == "NVFP4BlockScaling":
-    # NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances
+    # NVFP4 DCP load goes through a dequant + quant, so need to relax tolerances
     torch.testing.assert_close(
         out2,
         out1,
         rtol=0.125,
         atol=0.25,
         msg=lambda x: f"NVFP4BlockScaling: Training step after DCP load produces different output: {x}",
     )
```
```python
# NVFP4 scale unpad/repad through FSDP2 introduces small numerical
# differences vs the manual dequantize-then-allgather path.
```
Tolerance relaxed 250× for NVFP4 allgather verification
The tolerance for _check_fp8_fsdp2_allgather on NVFP4Tensor jumped from atol=5e-4, rtol=5e-3 to atol=0.125, rtol=0.25. This test compares param.dequantize() against fp32_allgathered_params[name], validating round-trip numerical fidelity of the all-gather path. A 25% relative tolerance makes the check nearly a no-op for FP4 values. A comment citing the 4-bit mantissa precision ceiling would justify the new values.
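For reference, a rough back-of-the-envelope check (a sketch that assumes NVFP4 element values land on the E2M1 grid and ignores the FP8 block scales entirely):

```python
# Worst-case relative rounding error across the positive E2M1 grid
# (the region around zero is excluded; relative error there is unbounded).
e2m1 = [0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

worst = 0.0
for lo, hi in zip(e2m1, e2m1[1:]):
    mid = (lo + hi) / 2  # value that rounds with maximal error
    worst = max(worst, (hi - lo) / 2 / mid)

print(worst)  # ~0.33: a single FP4 round-trip only guarantees relative
              # accuracy of a few tens of percent, which is the scale of
              # the rtol=0.25 used above
```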
/te-ci L1 pytorch
```python
        msg=lambda x: f"Fresh model loaded from DCP checkpoint produces different output: {x}",
    )
elif recipe_name == "NVFP4BlockScaling":
    # NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances
```
Why do we need dequant + quant here?
```python
# When a CPU copy of a quantized tensor is requested (e.g. by
# torch DCP staging via ``x.new_empty(..., device="cpu")``), we
# save the high-precision values in a plain CPU dense tensor.
# For the DCP load path, we will re-quantize the high-precision values.
```
This fix seems ad hoc to me. It's not obvious why qtensor.new_empty(..., device="cpu") returns a quantized tensor while qtensor.new_empty(..., device="cuda") returns a plain tensor. I wonder if it would be cleaner to just return a plain tensor in all cases. Thoughts:
- It's uncomfortable how `new_empty` and `empty_like` would have different behavior. I suppose we could interpret `empty_like` as "make a tensor that matches the input" and `new_empty` as "call torch.empty with defaults taken from input", but that would be a private interpretation that no one else follows.
- Would this affect FSDP or CPU offloading?
- Given the weirdness, would it be worthwhile raising a warning if `new_empty` is called outside of DCP?
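To make the device-dependent behavior under discussion concrete, here is a minimal sketch of a wrapper subclass whose __torch_dispatch__ special-cases aten.new_empty by target device. ToyWrapper and its branching policy (plain tensor for CPU, wrapper otherwise) are illustrative assumptions, not TE's actual implementation:

```python
import torch

class ToyWrapper(torch.Tensor):
    """Hypothetical stand-in for a quantized wrapper subclass."""

    @staticmethod
    def __new__(cls, data):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data):
        self._data = data  # pretend this is the packed quantized payload

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func == torch.ops.aten.new_empty.default:
            src, size = args[0], args[1]
            device = torch.device(kwargs.get("device") or src.device)
            dtype = kwargs.get("dtype") or src.dtype
            if device.type == "cpu":
                # CPU target (e.g. DCP staging): hand back a plain dense tensor.
                return torch.empty(size, dtype=dtype, device=device)
            # Any other target keeps the wrapper type.
            return ToyWrapper(torch.empty(size, dtype=dtype, device=device))
        raise NotImplementedError(f"{func} is not handled in this sketch")

w = ToyWrapper(torch.randn(4))
staged = w.new_empty([2, 2], device="cpu")
print(type(staged))  # plain torch.Tensor under this toy policy
```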
```python
# torch DCP staging via ``x.new_empty(..., device="cpu")``), we
# save the high-precision values in a plain CPU dense tensor.
# For the DCP load path, we will re-quantize the high-precision values.
target_size = torch.Size(size) if len(size) > 0 else tensor.size()
```
An empty size is valid and it corresponds to a tensor with 1 entry (for the same reason that 2^0 = 1).
```python
>>> import torch
>>> x = torch.ones(123).new_empty([])
>>> print(x.numel())
1
```
Suggested change:
```diff
-target_size = torch.Size(size) if len(size) > 0 else tensor.size()
+target_size = size
```
```diff
 # differences vs the manual dequantize-then-allgather path.
 if isinstance(param, NVFP4Tensor):
-    tols = dict(atol=5e-4, rtol=5e-3)
+    tols = dict(atol=0.125, rtol=0.25)
```
Why are the tolerances so much bigger? Is it also due to the dequant+quant path? If so, the comment above is no longer relevant and should be replaced with a better one (but I would still like an explanation why we cannot just load the nvfp4 values from the checkpoint).
```python
# When a CPU copy of a quantized tensor is requested (e.g. by
# torch DCP staging via ``x.new_empty(..., device="cpu")``), we
# save the high-precision values in a plain CPU dense tensor.
# For the DCP load path, we will re-quantize the high-precision values.
```
Ok, I see now why you want to dequantize. I don't think this is needed though - we should be able to create the QuantizedTensor on the CPU and save it, no? I remember that the CPU offloading of the activations faced a similar problem and already had to support some CPU ops on the QuantizedTensor anyway.
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Force-pushed from 3589ffa to 4197bee
for more information, see https://pre-commit.ci
/te-ci L1 pytorch
```python
# allow-listed for ``torch.load(weights_only=True)`` (used
# internally by DCP async-staging) to accept the stream.
_TE_DType,
getattr,
```
Session-wide getattr whitelisted for weights_only=True loading
getattr is registered as a safe global at module import time. add_safe_globals is process-wide in PyTorch, so any torch.load(…, weights_only=True) call made anywhere in a session that has imported transformer_engine.pytorch — including checkpoint loads for entirely different models — now has getattr available to the pickle stream. A malicious checkpoint loaded elsewhere could use getattr to access sensitive attributes of any already-constructed object reachable from the whitelisted globals (e.g. getattr(Float8Quantizer_instance, 'amax_reduction_group') to obtain a process group, or to build callable gadget chains). The weights_only=True flag is specifically a defence against untrusted pickle payloads; adding a general-purpose reflective accessor defeats that defence.
A targeted alternative: serialize _fp8_dtype as its integer value (int(self._fp8_dtype)) and reconstruct it in _make_in_reduce_ex via TE_DType(int_value), then add TE_DType to safe globals instead of getattr. This preserves the weights_only invariant without whitelisting a reflective accessor.
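A minimal sketch of that alternative under stated assumptions (FakeDType, FakeQuantized, and _rebuild_fake are hypothetical stand-ins, not TE's classes): the dtype travels through the pickle stream as a plain int, so only the rebuild helper has to be allow-listed and getattr never does.

```python
import enum
import io
import torch

class FakeDType(enum.IntEnum):  # stand-in for the FP8 dtype enum
    kFloat8E4M3 = 0
    kFloat8E5M2 = 1

class FakeQuantized:
    def __init__(self, fp8_dtype, data):
        self._fp8_dtype, self._data = fp8_dtype, data

    def __reduce_ex__(self, protocol):
        # Serialize the enum as int(...): the enum class never appears in
        # the pickle stream, so no reflective accessor is needed on load.
        return (_rebuild_fake, (int(self._fp8_dtype), self._data))

def _rebuild_fake(dtype_int, data):
    # Reconstruct the enum from its integer value at load time.
    return FakeQuantized(FakeDType(dtype_int), data)

# Only the rebuild helper needs to be allow-listed for weights_only loads.
torch.serialization.add_safe_globals([_rebuild_fake])

buf = io.BytesIO()
torch.save(FakeQuantized(FakeDType.kFloat8E5M2, torch.zeros(2)), buf)
buf.seek(0)
restored = torch.load(buf, weights_only=True)
print(restored._fp8_dtype, restored._data)
```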
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Description
Fixes DCP Sync and Async checkpoint loading for MXFP8/NVFP4.
Fixes DCP Async checkpoint loading for all quantization recipes.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
- DCP Sync Checkpoint loading
- DCP Async Checkpointing
Checklist: